from utils_for_json import *
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset, DatasetDict, Dataset
from trl import SFTConfig, SFTTrainer
import argparse
import warnings
import warnings
from accelerate import Accelerator
from accelerate.utils import gather_object
from codebleu import calc_codebleu
import os
import torch.distributed as dist
from datetime import timedelta
import time
from transformers import DataCollatorWithPadding

# Silence every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Environment switches for NCCL and the HF tokenizers library; they are set
# before the process group and the Accelerator are created below.
os.environ.update(
    {
        'TORCH_NCCL_BLOCKING_WAIT': '1',
        'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
        'TOKENIZERS_PARALLELISM': "False",
    }
)

# Create the NCCL process group up front with a six-hour collective timeout.
# Skipped when PYCHARM_HOSTED is '1' (presumably a run started from the
# PyCharm IDE without a distributed launcher -- TODO confirm).
if not os.getenv('PYCHARM_HOSTED') == '1':
    dist.init_process_group(backend='nccl', timeout=timedelta(hours=6))


# Shared Accelerator instance, configured for bf16 mixed precision.
accelerator = Accelerator(mixed_precision='bf16')

# ZeRO stage taken from the active DeepSpeed config; -1 means that no
# DeepSpeed plugin is attached to the accelerator state.
zero_version = -1
if accelerator.state.deepspeed_plugin:
    deepspeed_config = accelerator.state.deepspeed_plugin.deepspeed_config
    zero_section = deepspeed_config.get('zero_optimization', {})
    zero_version = zero_section.get("stage")
    print(zero_version)



def build_output_path(model_path: str, lr: float, epochs: int) -> str:
    """Return the output directory name for one fine-tuning run.

    The name is ``<model dir name>_JsonSft_Prune_<lr>_<epochs>``.

    Args:
        model_path: Path (or hub id) of the base model.
        lr: Learning rate of the run.
        epochs: Number of training epochs of the run.

    Returns:
        A relative directory name, resolved against the working directory.
    """
    # normpath strips a trailing separator; without it basename("/m/llama/")
    # is "" and the run would be written to "_JsonSft_Prune_...".
    model_name = os.path.basename(os.path.normpath(model_path))
    return f"{model_name}_JsonSft_Prune_{lr}_{epochs}"


def load_model(model_path: str, zero_stage, device_index: int):
    """Load the causal LM in bf16 with FlashAttention-2 enabled.

    Args:
        model_path: Path (or hub id) of the base model.
        zero_stage: DeepSpeed ZeRO stage of the run (-1 or None when not
            running under DeepSpeed ZeRO).
        device_index: Index of the GPU *on this node* that the whole model
            is placed on when ZeRO-3 is not in use.

    Returns:
        The loaded model with ``config.pretraining_tp`` set to 1.
    """
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
        "attn_implementation": "flash_attention_2",
    }
    # Only pin the model to a device outside ZeRO-3; under ZeRO-3 the
    # original code also loads without a device_map (DeepSpeed is expected
    # to handle parameter placement itself).
    if zero_stage != 3:
        load_kwargs["device_map"] = {"": device_index}

    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    model.config.pretraining_tp = 1
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="/Pretrained_Language_Models/Meta-Llama-3.1-8B-Instruct", type=str)
    parser.add_argument("--lr", default=5e-7, type=float)
    parser.add_argument("--epochs", default=1, type=int)
    args = parser.parse_args()

    # load_prune_data is not defined in this file; presumably it comes from
    # the star import of utils_for_json -- TODO confirm.
    training_data = load_prune_data(path='./data/new_json_data.json')
    print('len:', len(training_data))
    # Convert the list of dicts to a Dataset
    training_dataset = Dataset.from_list(training_data)

    model_path = args.model_name
    print(model_path)
    output_path = build_output_path(model_path, args.lr, args.epochs)

    # Use the node-local rank for device placement: the global process_index
    # exceeds the number of GPUs per node on multi-node runs.
    model = load_model(model_path, zero_version, accelerator.local_process_index)

    sft_config = SFTConfig(
        output_dir=output_path,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        logging_steps=10,
        save_strategy='epoch',
        learning_rate=args.lr,
        bf16=True,
        tf32=True,
        max_grad_norm=0.3,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        disable_tqdm=False,
        report_to="tensorboard",
        max_seq_length=4096,
    )

    trainer = SFTTrainer(
        model,
        args=sft_config,
        train_dataset=training_dataset,
    )

    trainer.train()